import os
import torch
import wandb
import pathlib
from torchvision.utils import save_image
from torchvision import transforms
from PIL import Image
# from torchvision.transforms import InterpolationMode

def sqrt(x):
    """Return ``floor(sqrt(x))`` as an ``int``.

    Used to pick a near-square grid width (``nrow``) for image grids.
    The integer square root is exact for every non-negative integer,
    whereas a float32 computation rounds some values up
    (e.g. ``8193**2 - 1`` would give 8193 instead of 8192).

    Raises:
        ValueError: if ``x`` is negative.
    """
    import math

    # Truncating to int first is safe: floor(sqrt(floor(x))) == floor(sqrt(x))
    # for every x >= 0, so non-integer inputs give the same result as before.
    return math.isqrt(int(x))


def plot_img(output_path, epoch, x):
    """Write a grid of generated images for one epoch and log it to wandb.

    The batch ``x`` is clamped to [-1, 1] and saved (min-max normalized by
    ``save_image``) to ``<output_path>/images/epoch_<epoch>.png``; the
    ``images`` directory is created when missing. The grid width is
    ``sqrt(batch_size)`` (floor), giving a roughly square layout. The saved
    file is then logged to wandb under ``gen_img/epoch_<epoch>.png``.
    """
    img_dir = os.path.join(output_path, 'images')
    pathlib.Path(img_dir).mkdir(parents=True, exist_ok=True)
    out_file = os.path.join(img_dir, 'epoch_{}.png'.format(epoch))

    clamped = torch.clamp(x, -1, 1)
    grid_width = sqrt(clamped.size(0))
    save_image(clamped, out_file, normalize=True, nrow=grid_width)

    log_key = 'gen_img/epoch_{}.png'.format(epoch)
    wandb.log({log_key: wandb.Image(out_file)})


def plot_mix_img(args, dt, spath, x, i, j, n_pts, i_count):
    """Save mixed images for the pair (``i``, ``j``) and return their file names.

    Reads ``args.inex_resz``, ``args.inex_no_all``, ``args.inex_noise`` and
    ``args.inex_proj``.

    - When ``args.inex_resz > 1`` the images are first resized via ``resize``.
    - Unless ``args.inex_no_all`` is set, the whole (squeezed) tensor is saved
      as one grid ``m_all_<dt>_<i>_<j>.png`` with ``n_pts`` images per row
      and logged to wandb under ``x_mix/<file name>``.
    - Then ``n_pts`` individual images are written as
      ``ms_<dt>_<i>_<j>_<i_count>_<k>_<inex_proj>.png`` (not logged to wandb).
      With ``args.inex_noise`` the k-th image is taken from row
      ``(n_pts + 1) // 2`` of ``x``; otherwise it is ``x[k]``.

    Returns:
        list[str]: the individual file names (relative to ``spath``), in
        order ``k = 0 .. n_pts - 1``. The grid file is not included.
    """
    if args.inex_resz > 1:
        x = resize(args, x)

    if not args.inex_no_all:
        grid_name = 'm_all_{}_{}_{}.png'.format(dt, i, j)
        grid_file = os.path.join(spath, grid_name)
        # NOTE(review): the squeeze only happens on this branch, so the
        # per-image indexing below sees a different shape depending on
        # args.inex_no_all. squeeze() would also drop a size-1 channel axis
        # (grayscale data). Confirm against the expected shape of x.
        x = x.squeeze()
        save_image(x, grid_file, normalize=True, nrow=n_pts)
        wandb.log({'x_mix/{}'.format(grid_name): wandb.Image(grid_file)})

    middle_row = (n_pts + 1) // 2
    saved_names = []
    for k in range(n_pts):
        name = 'ms_{}_{}_{}_{}_{}_{}.png'.format(dt, i, j, i_count, k,
                                                 args.inex_proj)
        single = x[middle_row][k] if args.inex_noise else x[k]
        save_image(single, os.path.join(spath, name), normalize=True,
                   nrow=n_pts)
        saved_names.append(name)
    return saved_names


def resize(args, x):
    """Resize each element of ``x`` with nearest-neighbour interpolation.

    ``x`` is iterated along its first axis. Every element is passed through
    ``transforms.Resize(size=args.inex_resz, interpolation=Image.NEAREST)``
    (an int size scales the shorter edge to that length), and the results
    are stacked back along a new axis 0. The elements must be tensors of
    matching post-resize shape for ``torch.stack`` to succeed.
    """
    resizer = transforms.Resize(size=args.inex_resz,
                                interpolation=Image.NEAREST)
    resized = [resizer(item) for item in x]
    return torch.stack(resized, dim=0)


def resize_ori(args, x):
    """Resize a batch of original images; behaves exactly like ``resize``.

    Resizes each element of ``x`` (along axis 0) to ``args.inex_resz`` with
    nearest-neighbour interpolation and re-stacks the results. Kept under
    its own name for existing callers, but it delegates to ``resize`` so the
    two code paths cannot drift apart.
    """
    return resize(args, x)


def save_ori_img(args, dt, spath, x, n_pts, n_cls, i_iter):
    """Save the original (un-mixed) images and return their file names.

    The first two axes of ``x`` are merged into a single batch axis
    (presumably ``(n_cls, n_pts, ...)`` -> ``(n_cls * n_pts, ...)``; the
    per-image indexing below assumes that row-major layout). When
    ``args.inex_resz > 1`` the batch is resized via ``resize_ori``.

    - Unless ``args.inex_no_all`` is set, the full batch is saved as a grid
      ``o_all_<dt>_<i_iter>.png`` with ``n_pts`` images per row and logged to
      wandb under ``x_ori_<i_iter>/<file name>``.
    - Every image is then written individually as
      ``o_<dt>_<class>_<point>_<i_iter>.png`` (not logged to wandb).

    Returns:
        list[str]: the individual file names (relative to ``spath``), class
        by class, point by point. The grid file is not included.
    """
    batch = x.reshape(-1, *x.shape[2:])
    if args.inex_resz > 1:
        batch = resize_ori(args, batch)

    if not args.inex_no_all:
        grid_name = 'o_all_{}_{}.png'.format(dt, i_iter)
        grid_file = os.path.join(spath, grid_name)
        save_image(batch, grid_file, normalize=True, nrow=n_pts)
        wandb.log({'x_ori_{}/{}'.format(i_iter, grid_name):
                   wandb.Image(grid_file)})

    saved_names = []
    for cls_idx in range(n_cls):
        row_start = cls_idx * n_pts
        for pt_idx in range(n_pts):
            name = 'o_{}_{}_{}_{}.png'.format(dt, cls_idx, pt_idx, i_iter)
            save_image(batch[row_start + pt_idx], os.path.join(spath, name),
                       normalize=True, nrow=n_pts)
            saved_names.append(name)
    return saved_names


def save_mix_i_img(args, dt, spath, x_each, i, j, i_lamb, i_iter, i_count):
    """Save one mixed image and log it to wandb.

    ``x_each`` is resized via ``resize`` when ``args.inex_resz > 1``, saved
    (min-max normalized, one image per row) as
    ``m_<dt>_<i>_<j>_<i_lamb>_<i_count>_<inex_proj>.png`` inside ``spath``,
    and logged to wandb under ``x_mix_<i_iter>/<file name>``.

    Returns:
        list[str]: a single-element list holding the file name (relative to
        ``spath``).
    """
    image = resize(args, x_each) if args.inex_resz > 1 else x_each
    name = 'm_{}_{}_{}_{}_{}_{}.png'.format(dt, i, j, i_lamb, i_count,
                                            args.inex_proj)
    out_file = os.path.join(spath, name)
    save_image(image, out_file, normalize=True, nrow=1)
    wandb.log({'x_mix_{}/{}'.format(i_iter, name): wandb.Image(out_file)})
    return [name]


def img_show(imgs, resolution=128, max_img=4):
    """Show up to ``max_img`` images side by side in an OpenCV window.

    Args:
        imgs: a 4-D ``torch.Tensor`` or array-like batch of images. It is
            treated as channel-first ``(B, 3, H, W)`` when its second axis has
            size 3, otherwise as channel-last ``(B, H, W, C)``.
        resolution: each image occupies a ``resolution`` x ``resolution``
            tile in the displayed strip (cubic interpolation).
        max_img: number of images, taken from the front of the batch, to show.

    Floating-point inputs (any float dtype) are clipped to [0, 1] and scaled
    to [0, 255]. Other dtypes are assumed to already be in [0, 255]. The
    input is never modified. Blocks until a key is pressed in the window.

    Raises:
        ValueError: if ``imgs`` is not 4-D or there are no images to show.
    """
    import cv2
    import numpy as np

    if isinstance(imgs, torch.Tensor):
        # detach() lets tensors that require grad be converted to NumPy.
        np_imgs = imgs.detach().cpu().numpy()
    else:
        np_imgs = np.asarray(imgs)

    # Unpacking enforces a 4-D input; only the second axis is inspected.
    _, second_dim, _, _ = np_imgs.shape
    if second_dim == 3:
        # Channel-first -> channel-last, the layout OpenCV expects.
        np_imgs = np_imgs.transpose((0, 2, 3, 1))

    np_imgs = np_imgs[:max_img]
    if len(np_imgs) == 0:
        raise ValueError("img_show: no images to display")

    if np.issubdtype(np_imgs.dtype, np.floating):
        # Out-of-place so the caller's array / tensor memory is never written,
        # and every float width is clipped before the uint8 cast (no wrap).
        np_imgs = np.clip(np_imgs, 0.0, 1.0) * 255.0
    np_imgs = np_imgs.astype(np.uint8)

    # Tile all images horizontally (along the width axis) in one pass.
    tot_img = np.concatenate(list(np_imgs), axis=1)
    tot_img = cv2.resize(tot_img, dsize=(resolution * len(np_imgs), resolution),
                         interpolation=cv2.INTER_CUBIC)

    cv2.imshow("test", tot_img)
    cv2.waitKey(0)


